Historical Panchromatic Orthophoto Colorisation with a Generative Adversarial Network¶

By Kay Warrie

This is a deep learning model to colorize historical greyscale or panchromatic orthophotos and orthophoto mosaics. Greyscale images made using all wavelengths of the visible spectrum are called panchromatic; most historical images are panchromatic.

An orthophoto is an aerial photograph geometrically corrected ("orthorectified") such that the scale is uniform. It is the basis for most mapping solutions.

An orthophoto mosaic is a type of large-scale image created by stitching together a collection of orthophotos to produce a seamless, georeferenced image, for example the "Satellite" view in Google Maps, which is, by the way, mostly made from aerial photos and not from satellite images.

A Generative Adversarial Network (GAN) is a type of artificial intelligence: a generative model consisting of two neural networks, the generator and the discriminator.
The generator is a convolutional neural net that produces an image, and the discriminator is a model that tries to distinguish between the label data and the generated images.
The loss of the GAN's discriminator is calculated by passing it a batch of the generator's output and a batch of real data and seeing whether it can distinguish between the two. The loss of the generator is based on the output of the discriminator.
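As a minimal PyTorch sketch of this two-player loss (with tiny stand-in networks, not the actual architecture used below):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-ins for the real networks, just to illustrate the loss flow.
gen = nn.Conv2d(1, 2, kernel_size=3, padding=1)   # greyscale in, 2 color channels out
disc = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1),
                     nn.Flatten(), nn.LazyLinear(1))
bce = nn.BCEWithLogitsLoss()

gray = torch.randn(4, 1, 32, 32)                  # a batch of greyscale inputs
real = torch.randn(4, 3, 32, 32)                  # a batch of real color images

fake = torch.cat([gray, gen(gray)], dim=1)        # input + prediction = "fake" image

# Discriminator loss: real batch labelled 1, generated batch labelled 0.
d_real, d_fake = disc(real), disc(fake.detach())
loss_D = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

# Generator loss: how well the fakes fool the discriminator.
d_fooled = disc(fake)
loss_G = bce(d_fooled, torch.ones_like(d_fooled))
```

In a real training step `loss_D` and `loss_G` are backpropagated through their respective networks in alternation, as done in the MainModel class later in this notebook.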

A full explanation of how this model was constructed can be found in explanation.ipynb.

To train and run the final model, a series of command-line tools was constructed:

  • pretrain_unet.py -> initialise the U-net generator.
  • trainWeigthed.py -> train the full GAN on the weighted dataset.
  • inference.py -> test the resulting model on real greyscale images.

To showcase the inference results, an interactive webpage was constructed:

In [6]:
%%html
<iframe width="100%" height="500" style="border:none; overflow-x: hidden;"
src="https://warrieka.github.io/histo_ortho_viewer/?hidebanner=1" ></iframe>

Context and background:¶

There are many historical panchromatic aerial photos of Belgium, like those made by the national mapping agency NGI and by Allied and Axis aerial reconnaissance forces during WWII. The older private remote sensing operators, like the Belgian company Eurosense, also have large collections of historical data.

These are still in active use, for example for tracing the history of construction projects, tracking building violations, for historical and archaeological research, or just for communication and illustration purposes. They are mostly used as they are, without georeferencing, mosaicing or color optimisation. This is not ideal for interpretation purposes, as you cannot overlay other mapping data on these photos. Some efforts were made to create mosaics of these photos, like:

  • The 1971 panchromatic orthomosaic of Flanders, made from a series of photos flown by Eurosense for the Flemish government.
  • The 1955 orthomosaic of the city of Ghent, made by the team city-archaeology and Team Data of the city of Ghent, derived from a "forgotten" photo collection found in the archives of the department of public works of the Flemish government.
  • The 1940-1940 orthomosaic of Antwerp, derived from a heterogeneous collection of Allied and Axis aerial reconnaissance photos. This series was collected as source material for the book Vergeten Linies 3: Militair Erfgoed Binnen de Antwerpse Fortengordels Op Luchtfoto en Lidar, by prof. Wouter Gheyle and Ignace Bourgeois, published by Provincie Antwerpen, 2018. The processing of the data was done by the city of Antwerp; a lot of processing was needed to match and improve colors, remove clouds and artefacts, and match resolutions.

Goals:¶

None of the previously mentioned mosaics was colorized; they are only available in grayscale.
The goal of this project is to create a tool to colorize these kinds of mosaics while preserving resolution and geographical metadata.

Most automated colorisation algorithms are based on convolutional neural networks (CNNs). Originally traditional models were used; later models added a Generative Adversarial Network (GAN) training phase.

An older approach is "Colorful Image Colorization" by Richard Zhang (2016). This was a regular classification CNN, no GAN just yet, so it could only handle a limited number of features.

The most successful approach has been DeOldify by Jason Antic (2018). This is the model that forms the basis of most commercial colorising software today.

Some newer models are based on image-to-image diffusion models, like ControlNet by Lvmin Zhang (2023), built on top of Stable Diffusion. But these have huge GPU demands and a serious tendency to hallucinate new information and destroy existing data, so they are not suited for our purposes.

All of these networks have been trained on "normal" photos and not on aerial imagery, and they tend to perform poorly when reconstructing an orthophoto. I'll need to train my own network.

Similar projects¶

I did some further research and found some similar projects that translate one type of geospatial imagery into fake orthophotos.

  • In this example from ESRI they generate a fake orthophoto from elevation data: "Generating rgb imagery from surface elevation using Pix2Pix"
  • This person made fake aerial photos from 19th-century Ordnance Survey maps: "map2sat: Satellite Image Generation Conditioned on Maps, Generate Your Own Scotland" by Miguel Espinosa et al.
  • Both articles are based on the original paper by Phillip Isola et al. (2016): "Image-to-Image Translation with Conditional Adversarial Networks".

In the following approach ESRI uses a CycleGAN instead, to translate SAR images (in the radio-wave part of the spectrum) into colorized photos. A CycleGAN is explicitly trained with unpaired data, forcing the model to find a broad "style" in the source data and transfer it to the target:

  • "SAR to RGB image translation using CycleGAN"
  • This article was based on "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" by Jun-Yan Zhu (2017).

Final approach¶

While these approaches are valid, one of the advantages of greyscale imagery is that you already have part of the color information: in the CIELab color space the L component is identical to the image in greyscale, so you only need to predict 2 values instead of 3, as you would when translating a map or a different part of the spectrum to an RGB image. The lightness (L) value of a grey pixel is composed of the original RGB values according to this formula:

$$ L = 0.30 \times R + 0.59 \times G + 0.11 \times B $$

I found the article "Colorizing black & white images with U-Net and conditional GAN" by Moein Shariatnia, published on November 18, 2020 in Towards Data Science. In this article he outlines how to predict the ab-values of a greyscale image in the CIELab color space and then recompose them with the original image to create an RGB image. For the predictions he uses a variation of the U-net model, modified to output 2 channels. He mostly uses pytorch, torchvision and fast.ai for the implementation of his model; he also uses scikit-image for color-space manipulations, as torchvision is lacking in this regard.
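This decompose-and-recompose idea can be sketched with scikit-image (a minimal round trip on a random array standing in for an RGB image):

```python
import numpy as np
from skimage.color import rgb2lab, lab2rgb

rng = np.random.default_rng(0)
rgb = rng.random((4, 4, 3))          # hypothetical RGB image, values in [0, 1]

lab = rgb2lab(rgb)                   # L in [0, 100], a/b roughly in [-110, 110]
L, ab = lab[..., :1], lab[..., 1:]   # L is what a greyscale photo already gives you

# A colorisation model only has to predict the 2 ab channels; recomposing
# the original L with (here, the true) ab channels restores the RGB image.
recomposed = lab2rgb(np.concatenate([L, ab], axis=-1))
assert np.allclose(recomposed, rgb, atol=1e-4)
```

In the model, the true ab channels are replaced by the generator's prediction, while the L channel is taken straight from the greyscale source.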

I largely copied his code, but I made several changes to fit it to training and inference on older geospatial orthophotos, like a specific augmentation function and reading data with GDAL, a library that preserves geospatial metadata when reading data and offers several utilities to deal with geodata, unlike PIL, torchvision or OpenCV.

The remainder of the code in this notebook goes through all the steps of creating this model. These are a bit simplified compared to the full implementation, but they are fully functional and will create a model, just not a very good one. If you want to try this yourself you will need to download the source data as described in the chapter Datasources.

Install the necessary libraries if needed¶

REMARK: check https://pytorch.org for the best installation option for your system. Also make a separate environment for this project.

To make an environment you can run the following in powershell:

python -m venv ortho_env
./ortho_env/Scripts/Activate.ps1

Then install the dependencies.

Alternatively you can use Docker to isolate your code; a Dockerfile is provided.

In [ ]:
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
%pip install fastai
%pip install GDAL
%pip install timm
%pip install scikit-image
%pip install tqdm ipywidgets
%pip install pillow
%pip install matplotlib
%pip install pandas

Import all the libraries¶

  • Use pytorch to create the neural net
  • numpy for arrays, pandas for tables and matplotlib for plotting
  • GDAL is used to read geospatial data
  • skimage for processing images
In [7]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

from osgeo import gdal
gdal.UseExceptions() # needs to be called, so gdal will have readable exceptions

import matplotlib.pyplot as plt

from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imread

import time, datetime,  os.path
from pathlib import Path
from typing import Iterable

from tqdm.notebook import tqdm
In [6]:
%matplotlib inline 
import warnings
warnings.simplefilter("ignore", category=RuntimeWarning)
warnings.simplefilter('ignore', category=UserWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEBUG = True

Preprocessing¶

Obtaining data¶

Datasources¶

  • RGB-colored Aerial photo's for training (Open data, from the Flemish Government):
    • 2023: The most recent images at the time of writing https://download.vlaanderen.be/product/10426-orthofotomoza%C3%AFek-middenschalig-winteropnamen-kleur-2023-vlaanderen
    • 2015: The best available resolution https://download.vlaanderen.be/product/864-grootschalige_orthofotomozaieken
    • 1979-1990: The oldest color images of Flanders https://download.vlaanderen.be/product/602-orthofotomoza%C3%AFek-kleinschalig-zomeropnamen-kleur-1979-1990-vlaanderen
  • Potential sources for real panchromatic aerial photo's for testing:
    • NGI: http://www.cartesius.be/CartesiusPortal/
    • Digitaal Vlaanderen: https://www.vlaanderen.be/datavindplaats/catalogus/orthofotomozaiek-kleinschalig-zomeropnamen-0
    • City of Gent: https://stad.gent/nl/cultuur-vrije-tijd/cultuur/hoe-zag-jouw-buurt-eruit-de-jaren-50
    • Panchromatic photos from the internal collections of the city of Antwerp https://felixarchief.antwerpen.be/archievenoverzicht/168417.
  • Landuse: https://www.vlaanderen.be/datavindplaats/catalogus/bodembedekkingskaart-bbk-1m-resolutie-opname-2018

Tiling the source data from JPEG 2000 to JPEG¶

Large-scale mosaics are usually stored as JPEG 2000, a successor of the JPEG format that supports internal tiling and wavelet compression. torchvision can't read JPEG 2000, and the files are too large to process as a single array.

GDAL has a fast tool for this, gdal_retile (https://gdal.org/programs/gdal_retile.html). It preserves the geographical properties and metadata in a csv file, and it also allows writing some overlap, which is necessary if you are going to apply convolutions on the data. Files containing only nodata are omitted.

For each JPEG 2000 file, write a collection of tiles and an index csv file that describes these tiles. For example for the files of 2023, in powershell:

foreach ($G in ls W:\2023\*.jp2 ){ 
    gdal_retile -co WORLDFILE=YES -overlap 100 -ps 512 512 -csv "W:\\2023_tiles\\$($G.Basename).csv" -f JPEG -ot Byte -targetDir "W:\\2023_tiles" $G.fullname
}

For the 0.15 meter resolution photos of 2023 this gave me 2626657 files.

The making of a weighted dataset¶

I will use the Flemish land-use dataset "Bodem Bedekkings Kaart (BBK)" from the Flemish Government:
https://www.vlaanderen.be/datavindplaats/catalogus/bodembedekkingskaart-bbk-1m-resolutie-opname-2018

The BBK is a segmentation map derived from multispectral (RGB-NIR) aerial images that shows the land cover in Flanders. Based on pixel classification supplemented with vectorial ground-truth data, the data is divided into low green, agricultural, forest, waterway and buildings. The target audience of these maps at resolutions of 1 m and 5 m is the general user who wants to consult a land-cover map of Flanders as a basis for various analyses relating to land cover or land use.

This map has 14 broad categories + 0 for no data (usually outside Flanders):

  1. Buildings
  2. Roads
  3. Other constructed
  4. Railroads
  5. Water
  6. Other Natural
  7. Field - Agriculture
  8. Grass - Bushes (Nature)
  9. Trees (Nature)
  10. Grass Bushes (Agriculture)
  11. Grass Bushes (Road Edge)
  12. Trees (Road edge)
  13. Grass Bushes (Water Edge)
  14. Trees (Water Edge)

A colored image of the BBK

For the extent of every photo we sample the dominant category. We use the Zonal statistics tool in QGIS for this: https://docs.qgis.org/3.34/en/docs/user_manual/processing_algs/qgis/rasteranalysis.html#qgiszonalstatisticsfb

The result is saved to an arrow file.
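Conceptually, the "majority" zonal statistic just takes the most frequent class code within each tile's extent; a minimal numpy sketch (with hypothetical BBK codes):

```python
import numpy as np

def dominant_category(window: np.ndarray) -> int:
    """Most frequent class code in a raster window: what the QGIS
    zonal-statistics 'majority' statistic computes per tile extent."""
    values, counts = np.unique(window, return_counts=True)
    return int(values[np.argmax(counts)])

tile = np.array([[7, 7, 9],
                 [7, 5, 9],
                 [7, 7, 1]])         # hypothetical BBK class codes
dominant_category(tile)              # → 7 (Field - Agriculture)
```

QGIS does the same thing at scale, reading the BBK raster clipped to each tile's footprint.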

Below you see this file plotted on the train data: yellow for agriculture types, green for natural types, blue for water and reds for constructed types.

A colored image of the BBK

In [5]:
BBK_CATEGORIES = {
    0: "Out of scope, onbekend",
    1: "Gebouwen",
    2: "Autowegen",
    3: "Overig Afgedekt",
    4: "Spoorwegen",
    5: "Water",
    6: "Overig Onafgedekt",
    7: "Akker (Landbouw)",
    8: "Gras  Struiken (Groen)",
    9: "Bomen (Groen)",
    10: "Gras  Struiken (Landbouw)",
    11: "Gras  Struiken (Wegrand)" ,
    12: "Bomen (Wegrand)" ,
    13: "Gras  Struiken (Waterrand)",
    14: "Bomen (Waterrand)"}

grouped_categories=  {"UNKNOWN": [0], "BUILDING": [1], "ROADLIKE":[2,3,4], "GREEN": [6,8,10,11], 
                      "AGRO": [7], "WOOD": [9,12], "WATER":[5,13,14] }
In [6]:
ds = pd.read_feather(r"W:\1989_tiles\index_landuse.arrow") # r"W:\2023_tiles\2023tiles_landuse.arrow") # r"W:\2015_tiles\2015tiles_landuse.arrow")
#group this to a few simpler categories
ds["CATEGORY"] = ds["BBK_CAT"].map({v: k for k,vv in grouped_categories.items() for v in vv})
ds["CATEGORY"] = ds["CATEGORY"].astype("category")
ds["BBK_CAT"] = ds["BBK_CAT"].astype("category")
ds["BBK_CAT"]= ds["BBK_CAT"].cat.rename_categories(BBK_CATEGORIES.values())
ds[["path", "BBK_CAT","CATEGORY"]].sample(5)
Out[6]:
path BBK_CAT CATEGORY
292613 W:\2015_tiles\K127n\OGWRGB13_15VL_K127n_19_32.jpg Gras Struiken (Landbouw) GREEN
717788 W:\2015_tiles\K205n\OGWRGB13_15VL_K205n_14_68.jpg Gras Struiken (Landbouw) GREEN
942831 W:\2015_tiles\K243n\OGWRGB13_15VL_K243n_18_72.jpg Bomen (Groen) WOOD
774678 W:\2015_tiles\K214z\OGWRGB13_15VL_K214z_24_18.jpg Akker (Landbouw) AGRO
509148 W:\2015_tiles\K164n\OGWRGB13_15VL_K164n_11_39.jpg Akker (Landbouw) AGRO

Also check that the files exist on disk, and remove entries that don't exist or are more than half black (= nodata).

In [ ]:
ds = ds[ds["path"].apply(lambda x: os.path.exists(x) and np.median( imread(x) ) != 0 ) ]
ds.reset_index(inplace=True)

Calculate the percentage of occurrences of each CATEGORY.

In [8]:
ax = (100* ds.CATEGORY.value_counts() / ds.CATEGORY.count()).plot(
    edgecolor='#fff', kind='bar',alpha=0.9, rot=0,                                                     
    color =['#0b0', '#ffff0e','#eee', '#964B00',  '#d62728',  '#0ff', '#f0f'])

ax.set_title('The percentage of occurrences of CATEGORY')
ax.set_xlabel(None)
ax.set_ylabel('%')
Out[8]:
Text(0, 0.5, '%')

We can use these inverse value counts as weights for the tiles: the less common a category, the more likely a tile will be picked in weighted sampling.

In [9]:
weights = 1/ds["CATEGORY"].value_counts()
weights
Out[9]:
CATEGORY
GREEN       0.000002
AGRO        0.000002
WOOD        0.000005
ROADLIKE    0.000016
BUILDING    0.000021
WATER       0.000039
Name: count, dtype: float64
In [45]:
# set as WEIGHT; for readability and performance, convert to integer
ds['WEIGHT'] = ds['CATEGORY'].map( weights*10e6 ).astype('int32')
# you can use this field to sample this dataset in balanced manner, 
# replace is False so a photo won't be picked twice. 
ds.sample(5,  weights='WEIGHT', replace=False)[['path',"BBK_CAT","CATEGORY","WEIGHT"]]
Out[45]:
path BBK_CAT CATEGORY WEIGHT
179228 W:\2015_tiles\K085z\OGWRGB13_15VL_K085z_25_01.jpg Akker (Landbouw) AGRO 23
798295 W:\2015_tiles\K217z\OGWRGB13_15VL_K217z_39_01.jpg Akker (Landbouw) AGRO 23
884467 W:\2015_tiles\K233n\OGWRGB13_15VL_K233n_39_61.jpg Gras Struiken (Waterrand) WATER 387
1093354 W:\2015_tiles\K267z\OGWRGB13_15VL_K267z_10_38.jpg Akker (Landbouw) AGRO 23
86016 W:\2015_tiles\K065z\OGWRGB13_15VL_K065z_45_22.jpg Akker (Landbouw) AGRO 23
In [46]:
# since water is mostly featureless, let's lower its weight a bit more
ds.loc[ds.CATEGORY == 'WATER', 'WEIGHT'] = ds.loc[ds.CATEGORY == 'WATER', 'WEIGHT'] // 3

# as well for unknown
ds.loc[ds.CATEGORY == 'UNKNOWN', 'WEIGHT'] = ds.loc[ds.CATEGORY == 'UNKNOWN', 'WEIGHT'] // 3
In [47]:
ds.WEIGHT.unique()
Out[47]:
array([ 23,  46,  20, 160, 210, 129])
In [65]:
ds.sample(10,  weights='WEIGHT', replace=False)[['path',"BBK_CAT","CATEGORY","WEIGHT"]]
Out[65]:
path BBK_CAT CATEGORY WEIGHT
369338 W:\2015_tiles\K141n\OGWRGB13_15VL_K141n_43_05.jpg Akker (Landbouw) AGRO 23
1274885 W:\2015_tiles\K305z\OGWRGB13_15VL_K305z_40_04.jpg Bomen (Groen) WOOD 46
783289 W:\2015_tiles\K215z\OGWRGB13_15VL_K215z_38_49.jpg Gras Struiken (Groen) GREEN 20
872012 W:\2015_tiles\K231z\OGWRGB13_15VL_K231z_24_08.jpg Water WATER 129
474967 W:\2015_tiles\K157z\OGWRGB13_15VL_K157z_05_22.jpg Gebouwen BUILDING 210
255195 W:\2015_tiles\K118z\OGWRGB13_15VL_K118z_29_41.jpg Akker (Landbouw) AGRO 23
1129828 W:\2015_tiles\K282z\OGWRGB13_15VL_K282z_04_27.jpg Overig Afgedekt ROADLIKE 160
1205696 W:\2015_tiles\K294z\OGWRGB13_15VL_K294z_17_01.jpg Gras Struiken (Landbouw) GREEN 20
982207 W:\2015_tiles\K248n\OGWRGB13_15VL_K248n_43_58.jpg Gras Struiken (Landbouw) GREEN 20
582338 W:\2015_tiles\K175z\OGWRGB13_15VL_K175z_37_65.jpg Water WATER 129

Save the results back to an arrow file. Then repeat this flow for 2015 and 2023.

In [66]:
ds[['path',"BBK_CAT","CATEGORY","WEIGHT"]].to_feather('.\\data\\tiles_1989_weighted.arrow', compression='lz4')

Merging Results¶

The resulting arrow files are added to the project.

I had about 60000 images from 1989 at a ground resolution of 1 m, 1.65 million from 2015 at a resolution of 25 cm, and 2.62 million from 2023 at a resolution of 15 cm. While the resolution of 2023 is higher than that of 2015, the quality is lower, so I also reduce its importance a bit.

The dataset from 1989 is derived from analog color images, while those from 2015 and 2023 were taken digitally.

In [3]:
df1989 = pd.read_feather("data\\tiles_1989_weighted.arrow")
df1989.sample(4,  weights='WEIGHT')
Out[3]:
path BBK_CAT CATEGORY WEIGHT
17861 W:\1989_tiles\OKZRGB79_90VL_K16\OKZRGB79_90VL_... Gras Struiken (Groen) GREEN 5
58613 W:\1989_tiles\OKZRGB79_90VL_K39\OKZRGB79_90VL_... Akker (Landbouw) AGRO 4
45809 W:\1989_tiles\OKZRGB79_90VL_K29\OKZRGB79_90VL_... Autowegen ROADLIKE 69
21418 W:\1989_tiles\OKZRGB79_90VL_K17\OKZRGB79_90VL_... Overig Afgedekt ROADLIKE 69
In [4]:
df2015 = pd.read_feather("data\\tiles_2015_weighted.arrow")
df2015.sample(4, weights='WEIGHT')
Out[4]:
path BBK_CAT CATEGORY WEIGHT
819860 W:\2015_tiles\K222z\OGWRGB13_15VL_K222z_27_38.jpg Akker (Landbouw) AGRO 23
571965 W:\2015_tiles\K174n\OGWRGB13_15VL_K174n_48_66.jpg Spoorwegen ROADLIKE 160
1455586 W:\2015_tiles\K336n\OGWRGB13_15VL_K336n_04_57.jpg Gebouwen BUILDING 210
386341 W:\2015_tiles\K143z\OGWRGB13_15VL_K143z_21_04.jpg Akker (Landbouw) AGRO 23
In [5]:
df2023 = pd.read_feather("data\\tiles_2023_weighted.arrow")
df2023.sample(4, weights='WEIGHT')
Out[5]:
index path BBK_CAT CATEGORY WEIGHT
2038408 2281336 W:\2023_tiles\K327n\OMWRGBMRVL_K327n_58_80.jpg Akker (Landbouw) AGRO 16
1709589 1933173 W:\2023_tiles\K293n\OMWRGBMRVL_K293n_14_01.jpg Overig Afgedekt ROADLIKE 52
887226 1041295 W:\2023_tiles\K187n\OMWRGBMRVL_K187n_57_14.jpg Akker (Landbouw) AGRO 16
617578 760235 W:\2023_tiles\K157n\OMWRGBMRVL_K157n_08_15.jpg Water WATER 38
In [6]:
df2023.WEIGHT = df2023.WEIGHT //2
df1989.WEIGHT = df2023.WEIGHT *30
In [7]:
df = pd.concat([df1989, df2015, df2023], ignore_index=True)
In [9]:
df[["path","BBK_CAT","CATEGORY","WEIGHT"]].sample(25,  weights='WEIGHT', replace=False)
Out[9]:
path BBK_CAT CATEGORY WEIGHT
2574101 W:\2023_tiles\K231z\OMWRGBMRVL_K231z_27_21.jpg Gebouwen BUILDING 44
29305 W:\1989_tiles\OKZRGB79_90VL_K21\OKZRGB79_90VL_... Akker (Landbouw) AGRO 1320
494807 W:\2015_tiles\K173z\OGWRGB13_15VL_K173z_04_54.jpg Overig Afgedekt ROADLIKE 160
291307 W:\2015_tiles\K136z\OGWRGB13_15VL_K136z_43_32.jpg Bomen (Groen) WOOD 46
1853191 W:\2023_tiles\K146n\OMWRGBMRVL_K146n_59_30.jpg Water WATER 19
2659380 W:\2023_tiles\K238z\OMWRGBMRVL_K238z_20_61.jpg Akker (Landbouw) AGRO 8
657942 W:\2015_tiles\K212z\OGWRGB13_15VL_K212z_07_56.jpg Gras Struiken (Landbouw) GREEN 20
810922 W:\2015_tiles\K236z\OGWRGB13_15VL_K236z_48_78.jpg Gras Struiken (Groen) GREEN 20
28133 W:\1989_tiles\OKZRGB79_90VL_K21\OKZRGB79_90VL_... Akker (Landbouw) AGRO 570
1264204 W:\2015_tiles\K334n\OGWRGB13_15VL_K334n_32_44.jpg Bomen (Groen) WOOD 46
448829 W:\2015_tiles\K165n\OGWRGB13_15VL_K165n_34_68.jpg Bomen (Groen) WOOD 46
1713048 W:\2023_tiles\K131z\OMWRGBMRVL_K131z_51_71.jpg Gebouwen BUILDING 44
3588204 W:\2023_tiles\K413n\OMWRGBMRVL_K413n_09_33.jpg Akker (Landbouw) AGRO 8
187947 W:\2015_tiles\K096n\OGWRGB13_15VL_K096n_09_06.jpg Bomen (Groen) WOOD 46
518381 W:\2015_tiles\K176z\OGWRGB13_15VL_K176z_19_33.jpg Gras Struiken (Landbouw) GREEN 20
2054616 W:\2023_tiles\K166z\OMWRGBMRVL_K166z_48_59.jpg Gebouwen BUILDING 44
2590652 W:\2023_tiles\K233n\OMWRGBMRVL_K233n_08_39.jpg Water WATER 19
2737742 W:\2023_tiles\K247n\OMWRGBMRVL_K247n_06_15.jpg Gras Struiken (Groen) GREEN 6
940628 W:\2015_tiles\K258n\OGWRGB13_15VL_K258n_30_61.jpg Bomen (Groen) WOOD 46
538364 W:\2015_tiles\K181n\OGWRGB13_15VL_K181n_40_35.jpg Bomen (Groen) WOOD 46
12086 W:\1989_tiles\OKZRGB79_90VL_K13\OKZRGB79_90VL_... Akker (Landbouw) AGRO 390
16443 W:\1989_tiles\OKZRGB79_90VL_K15\OKZRGB79_90VL_... Water WATER 270
2170028 W:\2023_tiles\K178n\OMWRGBMRVL_K178n_36_37.jpg Autowegen ROADLIKE 26
675418 W:\2015_tiles\K214z\OGWRGB13_15VL_K214z_39_60.jpg Gras Struiken (Groen) GREEN 20
3511 W:\1989_tiles\OKZRGB79_90VL_K07\OKZRGB79_90VL_... Bomen (Groen) WOOD 390
In [10]:
df.to_feather("data\\tiles_merged.arrow",  compression='lz4')

DATASET¶

In [64]:
def makeWeightedDataFromArrow(arrow, train_size=4000, test_size=1000,
                    pathField='NAME', weightField='WEIGHT', replacement=False):
    ds= pd.read_feather(arrow)
    train_paths = ds.sample(train_size, weights=weightField, replace=replacement)[pathField]
    test_paths = ds.sample(test_size, weights=weightField, replace=replacement)[pathField]
    return list(train_paths), list(test_paths)
In [65]:
train_paths, val_paths = makeWeightedDataFromArrow(
    r'.\data\tiles_2015_Weighted.arrow', train_size=160, test_size=40, 
    pathField='path', weightField='WEIGHT', replacement=False)
print(len(train_paths), len(val_paths))
160 40

Augmentation¶

In [7]:
from skimage.exposure import adjust_gamma, adjust_sigmoid
from skimage.util import random_noise
from skimage.filters import gaussian
from skimage.transform import resize

def grainify(img:np.ndarray):
    "Make it grainy like an old photo"
    c, rows, cols = img.shape
    val = np.random.uniform(0.036, 0.107)**2 

    # Full resolution
    noise_1 = np.zeros((rows, cols))
    noise_1 = random_noise(noise_1, mode='gaussian', var=val, clip=False)

    # # Half resolution
    noise_2 = np.zeros((rows//2, cols//2))
    noise_2 = random_noise(noise_2, mode='gaussian', var=(val*2)**2, clip=False)  
    noise_2 = resize(noise_2, (rows, cols))  # Upscale to original image size

    noise = noise_1 + noise_2 
    noise = np.stack( [noise]*c, axis=0)
    
    noisy_img = img/255 + noise # Add noise_im to the input image.
    return np.round((255 * noisy_img)).clip(0, 255).astype(np.uint8)

def aug(img:np.ndarray):
    "some data augmentation on img "
    img = adjust_gamma(img, gamma=np.random.uniform(low=0.5, high=1.5) ) #change lighting
    img = grainify(img)   #Make it grainy like an old photo
    img = adjust_sigmoid(img, gain= np.random.uniform(1,10) ) #contrast
    img = gaussian(img, np.random.uniform(0,1.5), channel_axis=0 ) # blur
    return (img*255).astype('uint8')

# Load random image using GDAL
imagePath = str( Path(r"W:\1979_tiles\OKZRGB79_90VL_K23\OKZRGB79_90VL_K23_0_9.png") )
image = gdal.Open(imagePath).ReadAsArray(buf_xsize=512, buf_ysize=512) #1pix == 1 meter

# Apply the transformation to your image
aug_image = aug(image)

fig,axs= plt.subplots(1,3)
axs[0].imshow( image.transpose((1, 2, 0)) )
axs[0].set_title('Original')
axs[1].imshow( aug_image.transpose((1, 2, 0))  )
axs[1].set_title('Random augmentation')
axs[2].imshow(  rgb2gray(aug_image , channel_axis=0 ) , cmap='Greys_r'  )
axs[2].set_title('Grayscale')
fig.set_size_inches(14,42)
fig.tight_layout()

Creating an iterable dataset object from the source data that is readable as a tensor in pytorch¶

In [67]:
class ColorizationDataset(Dataset):
    def __init__(self, paths:Iterable[os.PathLike], imsize:int=256, rootDir:str='', resize:bool=True):
        super().__init__()
        self.size = imsize
        self.paths = paths
        self.root = rootDir
        self.resize = resize

    def __getitem__(self, idx:int):
        imagePath = str(Path(self.root) / self.paths[idx]) 
        img = gdal.Open(imagePath).ReadAsArray(buf_xsize=self.size, buf_ysize=self.size) 
        img = aug(img)                              
        img_lab = rgb2lab( img , channel_axis=0 ) # Converting RGB to L*a*b
        img_lab = torch.tensor(img_lab, dtype=torch.float32) # Convert to Tensor
       
        L =  (img_lab[0] / 50. - 1).unsqueeze(0) # Between -1 and 1
        ab = img_lab[1:3] / 110  # Between -1 and 1
        return {'L': L, 'ab': ab}
    
    def __len__(self):
        return len(self.paths)
In [69]:
pretrain_dl = DataLoader(ColorizationDataset(train_paths, imsize=256), batch_size=8)
val_dl      = DataLoader(ColorizationDataset(val_paths, imsize=256), batch_size=4)
In [70]:
data = next(iter(val_dl))
Ls, abs_ = data['L'], data['ab']

print(Ls.shape, abs_.shape)
print(len(pretrain_dl), len(val_dl), len(train_paths), len(val_paths))
torch.Size([4, 1, 256, 256]) torch.Size([4, 2, 256, 256])
20 10 160 40

Generator¶

The generator is the main model that will perform the task of predicting the ab-values from a panchromatic image.
I use an existing model, "resnet18", that we can download with TIMM and convert to a Dynamic U-net with fast.ai.

In [3]:
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
import timm 

def ResUnet(n_input=1, n_output=2, size=224, timm_model_name='resnet18'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = timm.create_model(timm_model_name, pretrained=True)
    body = create_body(model, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G

Pretrain Generator¶

You can pretrain the generator a little by running a few images through it,
but without a discriminator this will not deliver good results.

In [4]:
class statsMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
In [5]:
def pretrain_generator(net_G, pretrain_dl, epochs, lrate=1e-3):
    opt = optim.Adam(net_G.parameters(), lr=lrate)
    criterion = nn.L1Loss() 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Started pretraining at {datetime.datetime.now()}")
    for e in range(epochs):
        loss_meter = statsMeter()
        for data in pretrain_dl:
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_meter.update(loss.item(), L.size(0))
            
        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")

    print(f"Finished pretraining at {datetime.datetime.now()}")
    
    return net_G
In [95]:
net_G = ResUnet(n_input=1, n_output=2, size=256)
net_G = pretrain_generator(net_G, pretrain_dl, epochs=5)
torch.save(net_G.state_dict(), "runs\\demo\\res18-unet_demo.pt")
Started pretraining at 2024-01-03 11:32:33.917362
Epoch 1/5
L1 Loss: 5204.15472
Epoch 2/5
L1 Loss: 0.02970
Epoch 3/5
L1 Loss: 0.00766
Epoch 4/5
L1 Loss: 0.00620
Epoch 5/5
L1 Loss: 0.00640
Finished pretraining at 2024-01-03 11:33:11.607399

Test Unet before GAN training¶

We see it produces images that look kind of sepia.

In [96]:
img_test = list( Path(r"W:\testdata\tiles_1950_gray").glob('*.png') )
randImg = lambda: str( img_test[ np.random.randint(0, len(img_test) ) ]) 

resunet = ResUnet()
state_dict = torch.load(Path(".\\runs\\demo\\res18-unet_demo.pt"), map_location=device)
resunet.load_state_dict(state_dict)
Out[96]:
<All keys matched successfully>
In [97]:
from model.tools import lab_to_rgb

testSet  = []
for i in range(4):
   ds =gdal.Open( randImg())
   b1 = torch.Tensor( ( ds.GetRasterBand(1).ReadAsArray() /128) -1 ).unsqueeze(0)
   testSet.append(b1.unsqueeze(0))

testSet = torch.cat(testSet)

f, axs= plt.subplots(4,2)
axs[0][0].imshow(testSet[0][0], cmap='Greys_r')
axs[1][0].imshow(testSet[1][0], cmap='Greys_r')
axs[2][0].imshow(testSet[2][0], cmap='Greys_r')
axs[3][0].imshow(testSet[3][0], cmap='Greys_r')

with torch.inference_mode():
   w= resunet(testSet.to(device))

colorized = lab_to_rgb(testSet, w.cpu())
axs[0][1].imshow(colorized[0])
axs[1][1].imshow(colorized[1])
axs[2][1].imshow(colorized[2])
axs[3][1].imshow(colorized[3])

f.set_size_inches(4,8)
f.tight_layout()
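The lab_to_rgb helper imported from model.tools is not shown in this notebook. A plausible sketch (an assumption on my part, not the actual implementation), undoing the normalisation from ColorizationDataset (L was scaled by /50 - 1, ab by /110), could look like this:

```python
import numpy as np
import torch
from skimage.color import lab2rgb

def lab_to_rgb(L: torch.Tensor, ab: torch.Tensor) -> np.ndarray:
    """Undo the dataset normalisation and convert batches of L (N,1,H,W)
    and ab (N,2,H,W) tensors to RGB arrays of shape (N,H,W,3) in [0,1]."""
    L = (L + 1.0) * 50.0                     # [-1, 1] -> [0, 100]
    ab = ab * 110.0                          # [-1, 1] -> [-110, 110]
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)

# mid-grey input (L=0 maps to 50 after rescaling) with neutral ab gives a grey image
out = lab_to_rgb(torch.zeros(1, 1, 8, 8), torch.zeros(1, 2, 8, 8))
```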

Patch Discriminator¶

This is the part of the GAN model that will act as the adversary. It should become much better at distinguishing between real and colorized images than the plain L1 loss we used in the pretraining.

In [98]:
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2) 
                          for i in range(n_down)]                                     
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] 
        self.model = nn.Sequential(*model)                                                   
        
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): 
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]  
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)
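Why a "patch" discriminator? Instead of a single real/fake score per image, the network outputs a grid of logits, one per receptive-field patch. Using the standard Conv2d output-size formula, we can check the grid size for a 256×256 input (a sketch; the strides mirror the class above: a first stride-2 layer, three down blocks of which the last has stride 1, and a final stride-1 layer):

```python
def conv_out(n: int, k: int = 4, s: int = 2, p: int = 1) -> int:
    """Spatial output size of a Conv2d layer along one dimension."""
    return (n + 2 * p - k) // s + 1

n = conv_out(256)             # first layer, stride 2: 256 -> 128
for s in (2, 2, 1):           # the three down blocks (last one has stride 1)
    n = conv_out(n, s=s)      # 128 -> 64 -> 32 -> 31
n = conv_out(n, s=1)          # final 1-channel layer: 31 -> 30
print(n)                      # 30: a 30x30 grid of patch logits
```

So each of the 30×30 outputs judges a local patch of the input, which pushes the generator toward locally plausible colors everywhere in the image.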

GAN Loss¶

Use the output of the discriminator model to calculate a loss for the generator. The loss of the GAN's generator is calculated as a binary cross-entropy loss between the discriminator's output and the label data.

In [112]:
class GANLoss(nn.Module):
    def __init__(self, real_label=1, fake_label=0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.BCEWithLogitsLoss()
    
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)
    
    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
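To make BCEWithLogitsLoss concrete, here is the same formula computed by hand in plain Python (a standalone sketch, not part of the model):

```python
import math

# BCE-with-logits for a single logit x against a target y in {0, 1}:
#   loss = -(y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x)))
def bce_with_logits(x, y):
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# An undecided discriminator (logit 0 -> probability 0.5) pays -log(0.5)
# no matter what the target is:
print(round(bce_with_logits(0.0, 1.0), 4))  # 0.6931
# A confidently correct "real" prediction pays almost nothing:
print(round(bce_with_logits(5.0, 1.0), 4))  # 0.0067
```

This is why the generator is trained against the target "real": its loss shrinks as the discriminator's probability of "real" for the fake images rises.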

Main Model: bringing it all together¶

This is the final Generative Adversarial Network (GAN), consisting of both the U-net generator and the patch discriminator.

In [113]:
class MainModel(nn.Module):
    def __init__(self, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        self.net_G = ResUnet().to(self.device)   # Generator
        self.net_D = self.init_weights(          # Discriminator
            PatchDiscriminator(input_c=3, n_down=3, num_filters=64)).to(self.device) 
        self.GANcriterion = GANLoss().to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def init_weights(self, net, gain:float=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and 'Conv' in classname:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif 'BatchNorm2d' in classname:
                nn.init.normal_(m.weight.data, 1., gain)
                nn.init.constant_(m.bias.data, 0.)

        net.apply(init_func)
        return net

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        self.fake_color = self.net_G(self.L)
    
    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
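Note the detach() in backward_D versus the live tensor in backward_G: the discriminator step must not push gradients back into the generator, while the generator step must. A minimal standalone illustration of that difference (not part of the model):

```python
import torch

x = torch.tensor([1.0], requires_grad=True)  # stand-in for a generator weight
fake = x * 2                                 # stand-in for net_G's output

# What the discriminator step sees: a detached copy, cut off from x,
# so loss_D cannot update the generator.
print(fake.detach().requires_grad)  # False

# What the generator step sees: the live tensor, still connected to x.
g_loss = (fake * 3).sum()
g_loss.backward()
print(x.grad.item())  # 6.0, since g_loss = 6 * x
```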

Train¶

Utils¶

Store stats, convert CIELAB data to RGB, and visualize results.

In [124]:
def create_loss_meters():
    loss_D_fake = statsMeter()
    loss_D_real = statsMeter()
    loss_D = statsMeter()
    loss_G_GAN = statsMeter()
    loss_G_L1 = statsMeter()
    loss_G = statsMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
    
def visualize(model, data):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    for i in range(4):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
        
def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.4f}")

Training¶

In [122]:
def train_model(model, train_dl, epochs):
    model.train()
    for e in range(epochs):
        loss_meter_dict = create_loss_meters() # dictionary of meters to log the epoch's losses
        for data in train_dl:
            model.setup_input(data) 
            model.optimize()
            batchsize = data['L'].size(0)
            update_losses(model, loss_meter_dict, count=batchsize) # function updating the log objects

        print(f"\nEpoch {e+1}/{epochs}") 
        log_results(loss_meter_dict) # function to print out the losses
        test_data = next(iter(val_dl))
        visualize(model, test_data) # function displaying the model's outputs
In [120]:
model = MainModel()
In [125]:
train_model(model, pretrain_dl, 5)
model.eval()
torch.save(model.state_dict(), 'runs\\models\\testRun.pth')
Epoch 1/5
loss_D_fake: 0.4419
loss_D_real: 0.4319
loss_D: 0.4369
loss_G_GAN: 1.3350
loss_G_L1: 4.6299
loss_G: 5.9649
(figure: sample colorizations after epoch 1)
Epoch 2/5
loss_D_fake: 0.4540
loss_D_real: 0.4412
loss_D: 0.4476
loss_G_GAN: 1.3924
loss_G_L1: 4.9903
loss_G: 6.3826
(figure: sample colorizations after epoch 2)
Epoch 3/5
loss_D_fake: 0.6094
loss_D_real: 0.5500
loss_D: 0.5797
loss_G_GAN: 1.2045
loss_G_L1: 4.9443
loss_G: 6.1488
(figure: sample colorizations after epoch 3)
Epoch 4/5
loss_D_fake: 0.5511
loss_D_real: 0.5669
loss_D: 0.5590
loss_G_GAN: 1.1266
loss_G_L1: 4.6503
loss_G: 5.7769
(figure: sample colorizations after epoch 4)
Epoch 5/5
loss_D_fake: 0.5397
loss_D_real: 0.5357
loss_D: 0.5377
loss_G_GAN: 1.1587
loss_G_L1: 4.5650
loss_G: 5.7236
(figure: sample colorizations after epoch 5)

Training in production:¶

In the example above I used ResNet18; in production I used the larger ResNet34 backbone.

To reduce memory use I trained with half-precision floating-point tensors, and I used Hugging Face Accelerate to speed up training. Accelerate also makes it possible to train on multiple GPUs.

I trained for about 12 hours on a single NVIDIA RTX 3080 GPU.

Inference¶

I use the trained model to colorize real panchromatic images: not color images, and not color images that were converted to greyscale.

The proof of the pudding is in the eating.
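One detail to keep in mind: the generator expects its greyscale input scaled to [-1, 1], the same range the L channel occupies during training (in lab_to_rgb, L in [-1, 1] maps to CIELAB lightness 0-100), so the 8-bit bands below are mapped with (x / 128) - 1. A quick check of the endpoints:

```python
# Endpoints of the (x / 128) - 1 rescaling for 8-bit pixel values 0..255:
lo = (0 / 128) - 1
hi = (255 / 128) - 1
print(lo, round(hi, 4))  # -1.0 0.9922
```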

In [ ]:
import torch, numpy as np, matplotlib.pyplot as plt
from osgeo import gdal
gdal.UseExceptions()
from model.unet import ResUnet
from model.tools import lab_to_rgb

device =  torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [3]:
W= ".\\runs\\models\\run32\\color_run32_resnet34_512_net_G40.pth"
model= ResUnet( size=512 , timm_model_name='resnet50') 
model.to(device)
model.load_state_dict( torch.load(W, map_location=device ) )
Out[3]:
<All keys matched successfully>

We use a JPEG 2000 image from 1970 obtained from the Flemish Government and a GeoTIFF from 1948 obtained from the NGI.

Both images have a ground resolution of about 1 meter.

We let numpy pick a random 512-by-512-pixel window in each image.

In [ ]:
imsize = 512
# two large scale black&white orthophoto mosaics of southern Antwerp: 
ds0 = gdal.Open( "W:\\1970\\OKZPAN71VL_K15.jp2" )      # 1970  1m resolution
ds1 = gdal.Open( "W:\\1948-1968\\1948antwZuid.tif" )   # 1948  1m resolution

# pick a random place on the photo's of 512x512 pixels
ds0_img = ds0.GetRasterBand(1).ReadAsArray(
    xoff= np.random.randint(ds0.RasterXSize - imsize), 
    yoff= np.random.randint(ds0.RasterYSize - imsize), 
    win_xsize=imsize, win_ysize=imsize) 

ds1_img = ds1.GetRasterBand(1).ReadAsArray(
    xoff= np.random.randint(ds1.RasterXSize - imsize), 
    yoff= np.random.randint(ds1.RasterYSize - imsize), 
    win_xsize=imsize, win_ysize=imsize) 
In [12]:
g0_img = torch.Tensor( ( ds0_img /128) -1 ).unsqueeze(0)
g1_img = torch.Tensor( ( ds1_img /128) -1 ).unsqueeze(0)

with torch.inference_mode():
    pred0 = model(g0_img.unsqueeze(0).to(device) )
    pred1 = model(g1_img.unsqueeze(0).to(device) )
 
colorized0 = lab_to_rgb(g0_img.unsqueeze(0), pred0.cpu())[0]
colorized1 = lab_to_rgb(g1_img.unsqueeze(0), pred1.cpu())[0]

f , axs = plt.subplots(2,2)
axs[0,0].imshow(g0_img[0], cmap='Greys_r')
axs[0,0].set_title('input')
axs[0,1].imshow(colorized0)
axs[0,1].set_title("colorized")
axs[1,0].imshow(g1_img[0], cmap='Greys_r')
axs[1,1].imshow(colorized1)
f.set_size_inches(15,15)
f.tight_layout()
(figure: greyscale inputs and colorized outputs for the 1970 and 1948 mosaics)